{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 评分预测算法"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 导入包\n",
    "import random\n",
    "import math\n",
    "import time\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 一. 通用函数定义"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 定义装饰器，监控运行时间\n",
    "def timmer(func):\n",
    "    def wrapper(*args, **kwargs):\n",
    "        start_time = time.time()\n",
    "        res = func(*args, **kwargs)\n",
    "        stop_time = time.time()\n",
    "        print('Func %s, run time: %s' % (func.__name__, stop_time - start_time))\n",
    "        return res\n",
    "    return wrapper"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1. 数据处理相关\n",
    "1. load data\n",
    "2. split data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Data():\n",
    "    \n",
    "    def __init__(self, user, item, rate, test=False, predict=0.0):\n",
    "        self.user = user\n",
    "        self.item = item\n",
    "        self.rate = rate\n",
    "        self.test = test\n",
    "        self.predict = predict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Dataset():\n",
    "    \n",
    "    def __init__(self, fp):\n",
    "        # fp: data file path\n",
    "        self.data = self.loadData(fp)\n",
    "    \n",
    "    def loadData(self, fp):\n",
    "        data = []\n",
    "        for l in open(fp):\n",
    "            data.append(tuple(map(int, l.strip().split('::')[:3])))\n",
    "        data = [Data(*d) for d in data]\n",
    "        return data\n",
    "    \n",
    "    def splitData(self, M, k, seed=1):\n",
    "        '''\n",
    "        :params: data, 加载的所有数据条目\n",
    "        :params: M, 划分的数目，最后需要取M折的平均\n",
    "        :params: k, 本次是第几次划分，k~[0, M)\n",
    "        :params: seed, random的种子数，对于不同的k应设置成一样的\n",
    "        :return: train, test\n",
    "        '''\n",
    "        random.seed(seed)\n",
    "        for i in range(len(self.data)):\n",
    "            # 这里与书中的不一致，本人认为取M-1较为合理，因randint是左右都覆盖的\n",
    "            if random.randint(0, M-1) == k:\n",
    "                self.data[i].test = True"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2. 评价指标\n",
    "RMSE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def RMSE(records):\n",
    "    rmse = {'train_rmse': [], 'test_rmse': []}\n",
    "    for r in records:\n",
    "        if r.test: \n",
    "            rmse['test_rmse'].append((r.rate - r.predict) ** 2)\n",
    "        else:\n",
    "            rmse['train_rmse'].append((r.rate - r.predict) ** 2)\n",
    "    rmse = {'train_rmse': math.sqrt(sum(rmse['train_rmse']) / len(rmse['train_rmse'])),\n",
    "            'test_rmse': math.sqrt(sum(rmse['test_rmse']) / len(rmse['test_rmse']))}\n",
    "    return rmse"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 二. 算法实现\n",
    "主要是不同的分组算法\n",
    "1. Cluster\n",
    "2. IdCluster\n",
    "3. UserActivityCluster\n",
    "4. ItemPopularityCluster\n",
    "5. UserVoteCluster\n",
    "6. ItemVoteCluster"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 1. Cluster\n",
    "class Cluster:\n",
    "    \n",
    "    def __init__(self, records):\n",
    "        self.group = {}\n",
    "        \n",
    "    def GetGroup(self, i):\n",
    "        return 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 2. IdCluster\n",
    "class IdCluster(Cluster):\n",
    "    \n",
    "    def __init__(self, records):\n",
    "        Cluster.__init__(self, records)\n",
    "        \n",
    "    def GetGroup(self, i):\n",
    "        return i"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 3. UserActivityCluster\n",
    "class UserActivityCluster(Cluster):\n",
    "    \n",
    "    def __init__(self, records):\n",
    "        Cluster.__init__(self, records)\n",
    "        activity = {}\n",
    "        for r in records:\n",
    "            if r.test: continue\n",
    "            if r.user not in activity:\n",
    "                activity[r.user] = 0\n",
    "            activity[r.user] += 1\n",
    "        # 按照用户活跃度进行分组\n",
    "        k = 0\n",
    "        for user, n in sorted(activity.items(), key=lambda x: x[-1], reverse=False):\n",
    "            c = int((k * 5) / len(activity))\n",
    "            self.group[user] = c\n",
    "            k += 1\n",
    "        \n",
    "    def GetGroup(self, uid):\n",
    "        if uid not in self.group:\n",
    "            return -1\n",
    "        else:\n",
    "            return self.group[uid]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 3. ItemPopularityCluster\n",
    "class ItemPopularityCluster(Cluster):\n",
    "    \n",
    "    def __init__(self, records):\n",
    "        Cluster.__init__(self, records)\n",
    "        popularity = {}\n",
    "        for r in records:\n",
    "            if r.test: continue\n",
    "            if r.item not in popularity:\n",
    "                popularity[r.item] = 0\n",
    "            popularity[r.item] += 1\n",
    "        # 按照物品流行度进行分组\n",
    "        k = 0\n",
    "        for item, n in sorted(popularity.items(), key=lambda x: x[-1], reverse=False):\n",
    "            c = int((k * 5) / len(popularity))\n",
    "            self.group[item] = c\n",
    "            k += 1\n",
    "        \n",
    "    def GetGroup(self, iid):\n",
    "        if iid not in self.group:\n",
    "            return -1\n",
    "        else:\n",
    "            return self.group[iid]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 4. UserVoteCluster\n",
    "class UserVoteCluster(Cluster):\n",
    "    \n",
    "    def __init__(self, records):\n",
    "        Cluster.__init__(self, records)\n",
    "        vote, cnt = {}, {}\n",
    "        for r in records:\n",
    "            if r.test: continue\n",
    "            if r.user not in vote:\n",
    "                vote[r.user] = 0\n",
    "                cnt[r.user] = 0\n",
    "            vote[r.user] += r.rate\n",
    "            cnt[r.user] += 1\n",
    "        # 按照物品平均评分进行分组\n",
    "        for user, v in vote.items():\n",
    "            c = v / (cnt[user] * 1.0)\n",
    "            self.group[user] = int(c * 2)\n",
    "        \n",
    "    def GetGroup(self, uid):\n",
    "        if uid not in self.group:\n",
    "            return -1\n",
    "        else:\n",
    "            return self.group[uid]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 5. ItemVoteCluster\n",
    "class ItemVoteCluster(Cluster):\n",
    "    \n",
    "    def __init__(self, records):\n",
    "        Cluster.__init__(self, records)\n",
    "        vote, cnt = {}, {}\n",
    "        for r in records:\n",
    "            if r.test: continue\n",
    "            if r.item not in vote:\n",
    "                vote[r.item] = 0\n",
    "                cnt[r.item] = 0\n",
    "            vote[r.item] += r.rate\n",
    "            cnt[r.item] += 1\n",
    "        # 按照物品平均评分进行分组\n",
    "        for item, v in vote.items():\n",
    "            c = v / (cnt[item] * 1.0)\n",
    "            self.group[item] = int(c * 2)\n",
    "        \n",
    "    def GetGroup(self, iid):\n",
    "        if iid not in self.group:\n",
    "            return -1\n",
    "        else:\n",
    "            return self.group[iid]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 返回预测接口函数\n",
    "def PredictAll(records, UserGroup, ItemGroup):\n",
    "    '''\n",
    "    :params: records, 数据集\n",
    "    :params: UserGroup, 用户分组类\n",
    "    :params: ItemGroup, 物品分组类\n",
    "    '''\n",
    "    userGroup = UserGroup(records)\n",
    "    itemGroup = ItemGroup(records)\n",
    "    group = {}\n",
    "    for r in records:\n",
    "        ug = userGroup.GetGroup(r.user)\n",
    "        ig = itemGroup.GetGroup(r.item)\n",
    "        if ug not in group:\n",
    "            group[ug] = {}\n",
    "        if ig not in group[ug]:\n",
    "            group[ug][ig] = []\n",
    "        # 这里计算的残差\n",
    "        group[ug][ig].append(r.rate - r.predict)\n",
    "    for ug in group:\n",
    "        for ig in group[ug]:\n",
    "            group[ug][ig] = sum(group[ug][ig]) / (1.0 * len(group[ug][ig]) + 1.0)\n",
    "    # predict\n",
    "    for i in range(len(records)):\n",
    "        ug = userGroup.GetGroup(records[i].user)\n",
    "        ig = itemGroup.GetGroup(records[i].item)\n",
    "        # 这里需要与之前的结果进行结合\n",
    "        records[i].predict += group[ug][ig]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 三. 实验\n",
    "\n",
    "UserGroup = [Cluster, IdCluster, Cluster, UserActivityCluster, Cluster, UserVoteCluster, Cluster, UserVoteCluster, UserActivityCluster]\n",
    "\n",
    "ItemGroup = [Cluster, Cluster, IdCluster, Cluster, ItemPopularityCluster, Cluster, ItemVoteCluster, ItemVoteCluster, ItemPopularityCluster]\n",
    "\n",
    "M=10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Experiment():\n",
    "    \n",
    "    def __init__(self, M, UserGroup, ItemGroup, fp='../dataset/ml-1m/ratings.dat'):\n",
    "        '''\n",
    "        :params: M, 进行多少次实验\n",
    "        :params: UserGroup, ItemGroup, 聚类算法类型\n",
    "        :params: fp, 数据文件路径\n",
    "        '''\n",
    "        self.userGroup = UserGroup\n",
    "        self.itemGroup = ItemGroup\n",
    "        self.dataset = Dataset(fp)\n",
    "        self.dataset.splitData(M, 0)\n",
    "    \n",
    "    # 定义单次实验\n",
    "    def worker(self, records):\n",
    "        '''\n",
    "        :params: train, 训练数据集\n",
    "        :params: test, 测试数据集\n",
    "        :return: train和test的rmse值\n",
    "        '''\n",
    "        PredictAll(records, self.userGroup, self.itemGroup)\n",
    "        metric = RMSE(records)\n",
    "        return metric\n",
    "    \n",
    "    # 多次实验取平均\n",
    "    def run(self):\n",
    "        metrics = {'train_rmse': 0, 'test_rmse': 0}\n",
    "        metric = self.worker(self.dataset.data)\n",
    "        print('Result (UserGroup={}, ItemGroup={}): {}'.format(\\\n",
    "               self.userGroup.__name__, \\\n",
    "               self.itemGroup.__name__, metric))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Result (UserGroup=Cluster, ItemGroup=Cluster): {'train_rmse': 1.1163419044799097, 'test_rmse': 1.1239098429180596}\n",
      "Result (UserGroup=IdCluster, ItemGroup=Cluster): {'train_rmse': 1.0279720784446358, 'test_rmse': 1.0328163104921269}\n",
      "Result (UserGroup=Cluster, ItemGroup=IdCluster): {'train_rmse': 0.9187033847868101, 'test_rmse': 0.921579422992697}\n",
      "Result (UserGroup=UserActivityCluster, ItemGroup=Cluster): {'train_rmse': 0.9165409575497123, 'test_rmse': 0.919459056096504}\n",
      "Result (UserGroup=UserActivityCluster, ItemGroup=IdCluster): {'train_rmse': 0.9077767398638859, 'test_rmse': 0.9107730379036226}\n",
      "Result (UserGroup=Cluster, ItemGroup=ItemPopularityCluster): {'train_rmse': 0.9077761064444031, 'test_rmse': 0.9107750684983894}\n",
      "Result (UserGroup=IdCluster, ItemGroup=ItemPopularityCluster): {'train_rmse': 0.8792394386766903, 'test_rmse': 0.8829060808248217}\n",
      "Result (UserGroup=UserActivityCluster, ItemGroup=ItemPopularityCluster): {'train_rmse': 0.8792392554915059, 'test_rmse': 0.8828976812104593}\n",
      "Result (UserGroup=UserVoteCluster, ItemGroup=Cluster): {'train_rmse': 0.8792392496838647, 'test_rmse': 0.8828916004055563}\n",
      "Result (UserGroup=UserVoteCluster, ItemGroup=IdCluster): {'train_rmse': 0.8695289142775611, 'test_rmse': 0.8730504482678775}\n",
      "Result (UserGroup=Cluster, ItemGroup=ItemVoteCluster): {'train_rmse': 0.8695285755303085, 'test_rmse': 0.8730516076632375}\n",
      "Result (UserGroup=IdCluster, ItemGroup=ItemVoteCluster): {'train_rmse': 0.8479002323916044, 'test_rmse': 0.8510565056391749}\n",
      "Result (UserGroup=UserVoteCluster, ItemGroup=ItemVoteCluster): {'train_rmse': 0.8478967181454204, 'test_rmse': 0.8510569122189332}\n"
     ]
    }
   ],
   "source": [
    "UserGroups = [Cluster, IdCluster, Cluster, UserActivityCluster, UserActivityCluster, Cluster, IdCluster, UserActivityCluster, UserVoteCluster, UserVoteCluster, Cluster, IdCluster, UserVoteCluster]\n",
    "ItemGroups = [Cluster, Cluster, IdCluster, Cluster, IdCluster, ItemPopularityCluster, ItemPopularityCluster, ItemPopularityCluster, Cluster, IdCluster, ItemVoteCluster, ItemVoteCluster, ItemVoteCluster]\n",
    "M = 10\n",
    "exp = Experiment(M, None, None)\n",
    "for i in range(len(UserGroups)):\n",
    "    exp.userGroup = UserGroups[i]\n",
    "    exp.itemGroup = ItemGroups[i]\n",
    "    exp.run()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 四. 实验结果\n",
    "\n",
    "效果比书中的要好\n",
    "\n",
    "Result (UserGroup=Cluster, ItemGroup=Cluster): {'train_rmse': 1.1163419044799097, 'test_rmse': 1.1239098429180596}\n",
    "\n",
    "Result (UserGroup=IdCluster, ItemGroup=Cluster): {'train_rmse': 1.0279720784446358, 'test_rmse': 1.0328163104921269}\n",
    "\n",
    "Result (UserGroup=Cluster, ItemGroup=IdCluster): {'train_rmse': 0.9187033847868101, 'test_rmse': 0.921579422992697}\n",
    "\n",
    "Result (UserGroup=UserActivityCluster, ItemGroup=Cluster): {'train_rmse': 0.9165409575497123, 'test_rmse': 0.919459056096504}\n",
    "\n",
    "Result (UserGroup=UserActivityCluster, ItemGroup=IdCluster): {'train_rmse': 0.9077767398638859, 'test_rmse': 0.9107730379036226}\n",
    "\n",
    "Result (UserGroup=Cluster, ItemGroup=ItemPopularityCluster): {'train_rmse': 0.9077761064444031, 'test_rmse': 0.9107750684983894}\n",
    "\n",
    "Result (UserGroup=IdCluster, ItemGroup=ItemPopularityCluster): {'train_rmse': 0.8792394386766903, 'test_rmse': 0.8829060808248217}\n",
    "\n",
    "Result (UserGroup=UserActivityCluster, ItemGroup=ItemPopularityCluster): {'train_rmse': 0.8792392554915059, 'test_rmse': 0.8828976812104593}\n",
    "\n",
    "Result (UserGroup=UserVoteCluster, ItemGroup=Cluster): {'train_rmse': 0.8792392496838647, 'test_rmse': 0.8828916004055563}\n",
    "\n",
    "Result (UserGroup=UserVoteCluster, ItemGroup=IdCluster): {'train_rmse': 0.8695289142775611, 'test_rmse': 0.8730504482678775}\n",
    "\n",
    "Result (UserGroup=Cluster, ItemGroup=ItemVoteCluster): {'train_rmse': 0.8695285755303085, 'test_rmse': 0.8730516076632375}\n",
    "\n",
    "Result (UserGroup=IdCluster, ItemGroup=ItemVoteCluster): {'train_rmse': 0.8479002323916044, 'test_rmse': 0.8510565056391749}\n",
    "\n",
    "Result (UserGroup=UserVoteCluster, ItemGroup=ItemVoteCluster): {'train_rmse': 0.8478967181454204, 'test_rmse': 0.8510569122189332}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 五. 总结\n",
    "1. 将Cluster定义成类，并提供接口\n",
    "2. 将聚类与预测分开实现，分别提供接口\n",
    "3. 将数据存储成类的形式，方便残差学习"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
